我們今天要來寫code測試Local的DynamoDB,我們找一個實際的例子來測試,目前想到會有需要存的就是使用者Google OAuth的token,我們不會希望每次要調用Google Drive的時候,都一直要你登入google進行授權,因此我們會把授權成功後拿到的token資訊存到DynamoDB去。
首先,開一個資料夾並建立4個檔案:dynamodb.go放操作DynamoDB的通用function,oauth.go放處理剛剛提到的token的相關方法,再加上兩者對應的Unit test檔案dynamodb_test.go與oauth_test.go。

接著到oauth.go建立GoogleOAuthToken的struct。至於如何取得真正的token,等到實際處理Google OAuth的部分再說明,今天先確保存取DynamoDB沒問題就好。
// oauth.go
// GoogleOAuthToken is the item stored in DynamoDB for a user's Google OAuth
// token. The dynamodbav tags give the attribute names used in the table:
// PK (partition key), access_token, token_type, refresh_token and expiry.
// NOTE(review): in this post PK carries the LINE user ID, and the "LINEID#"
// prefix is added when the item is keyed/written. Expiry is kept as a string;
// the sample data uses an RFC 3339 timestamp — TODO confirm the format.
type GoogleOAuthToken struct {
	PK           string `dynamodbav:"PK"`
	AccessToken  string `dynamodbav:"access_token"`
	TokenType    string `dynamodbav:"token_type"`
	RefreshToken string `dynamodbav:"refresh_token"`
	Expiry       string `dynamodbav:"expiry"`
}
知道要操作的資料結構後,我們回到dynamodb.go,建立TableBasics作為基礎結構。
// dynamodb.go
package dynamodb
import (
  "context"
	"errors"
	"fmt"
	"log"
	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/credentials"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)
// TableBasics bundles a DynamoDB client with the name of the table that the
// methods defined on it operate on.
type TableBasics struct {
	// DynamoDbClient is the client used for every table operation; tests swap
	// it for one pointing at a local DynamoDB.
	DynamoDbClient *dynamodb.Client
	// TableName is the DynamoDB table the methods read from and write to.
	TableName      string
}
// NewTableBasics returns a TableBasics for tableName whose client is created
// from the SDK's default configuration (config.LoadDefaultConfig).
// It panics if that configuration cannot be loaded.
func NewTableBasics(tableName string) *TableBasics {
	cfg, err := config.LoadDefaultConfig(context.TODO())
	if err != nil {
		panic(err)
	}

	basics := new(TableBasics)
	basics.TableName = tableName
	basics.DynamoDbClient = dynamodb.NewFromConfig(cfg)
	return basics
}
下面加上CreateLocalClient,建立一個連線到本機(local)DynamoDB的*dynamodb.Client;平時在本機開發時,就使用這個function取得的Client。
// dynamodb.go
// CreateLocalClient returns a DynamoDB client whose base endpoint is
// http://localhost:<port>/, for use against a locally running DynamoDB.
//
// The region is fixed to ap-northeast-1 and the credentials are static
// placeholders ("dummy"); as their Source note says, the values are
// irrelevant for local DynamoDB. It panics if the configuration cannot be
// loaded.
func CreateLocalClient(port int) *dynamodb.Client {
	endpoint := fmt.Sprintf("http://localhost:%d/", port)

	placeholder := aws.Credentials{
		AccessKeyID:     "dummy",
		SecretAccessKey: "dummy",
		SessionToken:    "dummy",
		Source:          "Hard-coded credentials; values are irrelevant for local DynamoDB",
	}
	cfg, err := config.LoadDefaultConfig(context.TODO(),
		config.WithRegion("ap-northeast-1"),
		config.WithCredentialsProvider(credentials.StaticCredentialsProvider{Value: placeholder}),
	)
	if err != nil {
		panic(err)
	}

	// Redirect every request to the local endpoint instead of AWS.
	return dynamodb.NewFromConfig(cfg, func(opts *dynamodb.Options) {
		opts.BaseEndpoint = aws.String(endpoint)
	})
}
接著打開dynamodb_test.go,我們將其作為測試的主要進入點(TestMain):建立一個TableBasics結構,其中TableName設為"google-oauth",再把原本的DynamoDbClient替換成CreateLocalClient產生的Client,讓它連線到我們昨天開在8000 port的local DynamoDB。這樣之後其他unit test直接操作初始化好的testTableBasics就好了~
// dynamodb_test.go
package dynamodb
import (
	"os"
	"testing"
)
// testTableBasics is the fixture shared by the package's tests. TestMain
// initializes it for the "google-oauth" table on the local DynamoDB
// listening on port 8000.
var testTableBasics *TableBasics

// TestMain builds the shared fixture once, then runs the package's tests and
// exits with their status.
func TestMain(m *testing.M) {
	fixture := NewTableBasics("google-oauth")
	// Swap in a client that targets the local DynamoDB started earlier.
	fixture.DynamoDbClient = CreateLocalClient(8000)
	testTableBasics = fixture

	os.Exit(m.Run())
}
再來回到oauth.go,補上建立table的方法 (從TableBasics拿出TableName,調用DynamoDbClient.CreateTable來建立table)
//oauth.go
package dynamodb
import (
	"context"
	"log"
	"time"
	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
	"github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)
// GoogleOAuthToken is the item stored in the Google OAuth table. The
// dynamodbav tags give the attribute names: PK (the partition key),
// access_token, token_type, refresh_token and expiry.
// PK is expected to hold the raw LINE user ID; GetKey and AddGoogleOAuthToken
// prepend PK_PREFIX_LINE when addressing or writing the item.
// NOTE(review): Expiry is a string; sample data uses an RFC 3339 timestamp —
// TODO confirm the format callers rely on.
type GoogleOAuthToken struct {
	PK           string `dynamodbav:"PK"`
	AccessToken  string `dynamodbav:"access_token"`
	TokenType    string `dynamodbav:"token_type"`
	RefreshToken string `dynamodbav:"refresh_token"`
	Expiry       string `dynamodbav:"expiry"`
}
// CreateGoogleOAuthTable creates basics.TableName with a single string hash
// key "PK" and 10/10 provisioned read/write capacity, then waits up to
// 5 minutes for the table to exist.
//
// If CreateTable fails, the error is logged and (nil, err) is returned. If the
// wait fails, the description from CreateTable is returned together with the
// (logged) wait error.
func (basics TableBasics) CreateGoogleOAuthTable() (*types.TableDescription, error) {
	ctx := context.TODO()

	created, err := basics.DynamoDbClient.CreateTable(ctx, &dynamodb.CreateTableInput{
		AttributeDefinitions: []types.AttributeDefinition{{
			AttributeName: aws.String("PK"),
			AttributeType: types.ScalarAttributeTypeS,
		}},
		KeySchema: []types.KeySchemaElement{{
			AttributeName: aws.String("PK"),
			KeyType:       types.KeyTypeHash,
		}},
		TableName: aws.String(basics.TableName),
		ProvisionedThroughput: &types.ProvisionedThroughput{
			ReadCapacityUnits:  aws.Int64(10),
			WriteCapacityUnits: aws.Int64(10),
		},
	})
	if err != nil {
		log.Printf("Couldn't create table %v. Here's why: %v\n", basics.TableName, err)
		return nil, err
	}

	waiter := dynamodb.NewTableExistsWaiter(basics.DynamoDbClient)
	describe := &dynamodb.DescribeTableInput{TableName: aws.String(basics.TableName)}
	waitErr := waiter.Wait(ctx, describe, 5*time.Minute)
	if waitErr != nil {
		log.Printf("Wait for table exists failed. Here's why: %v\n", waitErr)
	}
	return created.TableDescription, waitErr
}
接著我們將PK設計為LINE的User ID,因此在所有用到PK的地方都加上"LINEID#"的prefix,然後補上對應的新增、更新、查詢操作就好囉。
//oauth.go
// PK_PREFIX_LINE is prepended to a raw LINE user ID to form the value stored
// in the table's "PK" attribute.
const PK_PREFIX_LINE = "LINEID#"

// GetKey returns the primary-key map addressing tok's item:
// {"PK": S(PK_PREFIX_LINE + tok.PK)}. tok.PK must be the raw, un-prefixed ID.
func (tok GoogleOAuthToken) GetKey() map[string]types.AttributeValue {
	// Build the string attribute directly; marshalling a plain string cannot
	// fail, so no error path is needed.
	return map[string]types.AttributeValue{
		"PK": &types.AttributeValueMemberS{Value: PK_PREFIX_LINE + tok.PK},
	}
}
新增GoogleOAuthToken
//oauth.go
// AddGoogleOAuthToken writes tok to basics.TableName with PutItem.
// tok.PK must be the raw LINE user ID; PK_PREFIX_LINE is prepended here
// (tok is a copy, so the caller's value is not modified).
//
// Marshalling and PutItem errors are logged and returned to the caller.
func (basics TableBasics) AddGoogleOAuthToken(tok GoogleOAuthToken) error {
	tok.PK = PK_PREFIX_LINE + tok.PK
	item, err := attributevalue.MarshalMap(tok)
	if err != nil {
		// The function already reports failures via its error result, so do
		// not crash the caller with a panic.
		log.Printf("Couldn't marshal token %v. Here's why: %v\n", tok.PK, err)
		return err
	}
	_, err = basics.DynamoDbClient.PutItem(context.TODO(), &dynamodb.PutItemInput{
		TableName: aws.String(basics.TableName), Item: item,
	})
	if err != nil {
		log.Printf("Couldn't add item to table. Here's why: %v\n", err)
	}
	return err
}
更新GoogleOAuthToken
//oauth.go
// TxUpdateGoogleOAuthToken sets refresh_token and access_token on the item
// keyed by tok.PK (the raw LINE user ID; GetKey adds the prefix), using a
// single-item TransactWriteItems call. token_type and expiry are not touched.
//
// Errors from building the expression or from the transaction are logged and
// returned; on an expression error the output is nil.
// NOTE(review): no condition expression is set, so DynamoDB may create the
// item if it does not exist yet — confirm that is intended.
func (basics TableBasics) TxUpdateGoogleOAuthToken(tok GoogleOAuthToken) (*dynamodb.TransactWriteItemsOutput, error) {
	// UpdateBuilder.Set returns the updated builder; chain it instead of
	// discarding the result and relying on internal state being shared
	// between builder copies.
	update := expression.
		Set(expression.Name("refresh_token"), expression.Value(tok.RefreshToken)).
		Set(expression.Name("access_token"), expression.Value(tok.AccessToken))

	expr, err := expression.NewBuilder().WithUpdate(update).Build()
	if err != nil {
		log.Printf("Couldn't build expression for update. Here's why: %v\n", err)
		return nil, err
	}

	input := &dynamodb.TransactWriteItemsInput{
		TransactItems: []types.TransactWriteItem{
			{
				Update: &types.Update{
					Key:                       tok.GetKey(),
					TableName:                 aws.String(basics.TableName),
					ExpressionAttributeNames:  expr.Names(),
					ExpressionAttributeValues: expr.Values(),
					UpdateExpression:          expr.Update(),
				},
			},
		},
	}
	response, err := basics.DynamoDbClient.TransactWriteItems(context.TODO(), input)
	if err != nil {
		log.Printf("Couldn't transactionally update tok %v. Here's why: %v\n", tok.PK, err)
	}
	return response, err
}
取得GoogleOAuthToken
//oauth.go
// GetGoogleOAuthToken fetches the token stored for the LINE user line_userid
// (the raw ID, without PK_PREFIX_LINE).
//
// The returned token's PK is the raw ID, which is what GetKey,
// AddGoogleOAuthToken and TxUpdateGoogleOAuthToken expect. The token can
// therefore be passed straight back to them without the prefix being applied
// twice. GetItem and unmarshal errors are logged and returned.
// NOTE(review): when no item matches, GetItem returns no Item and the token
// appears to come back with only PK set and a nil error — callers should check
// e.g. AccessToken == ""; consider a dedicated not-found error.
func (basics TableBasics) GetGoogleOAuthToken(line_userid string) (GoogleOAuthToken, error) {
	tok := GoogleOAuthToken{PK: line_userid}
	response, err := basics.DynamoDbClient.GetItem(context.TODO(), &dynamodb.GetItemInput{
		Key: tok.GetKey(), TableName: aws.String(basics.TableName),
	})
	if err != nil {
		log.Printf("Couldn't get info about %v. Here's why: %v\n", line_userid, err)
		return tok, err
	}
	if err = attributevalue.UnmarshalMap(response.Item, &tok); err != nil {
		log.Printf("Couldn't unmarshal response. Here's why: %v\n", err)
		return tok, err
	}
	// The stored PK carries PK_PREFIX_LINE; restore the raw ID so the token
	// round-trips through GetKey/Add/TxUpdate without being double-prefixed.
	tok.PK = line_userid
	return tok, nil
}
到這邊基本的操作就寫好了,我們到oauth_test.go來測試一下。這幾個測試彼此有先後依賴(建表 → 新增 → 更新 → 查詢),所以請依序逐一執行個別的測試,並配合昨天 http://localhost:8001/ 上的GUI看一下效果,確認都有正確執行就沒問題囉~
// oauth_test.go
package dynamodb
import (
	"testing"
	"github.com/stretchr/testify/assert"
)
// TestCreateGoogleOAuthTable creates the fixture table on the local DynamoDB
// and expects a non-nil description with no error.
// NOTE(review): presumably fails if the table already exists, so run it once
// against a fresh local instance.
func TestCreateGoogleOAuthTable(t *testing.T) {
	desc, createErr := testTableBasics.CreateGoogleOAuthTable()
	t.Log("tableDesc:", desc)
	t.Log("ERROR:", createErr)

	assert.NoError(t, createErr, "Expected no error creating table")
	assert.NotNil(t, desc, "Table description should not be nil")
}
// TestAddGoogleOAuthToken writes a sample token for LINE user "test1234".
// It expects the table created by TestCreateGoogleOAuthTable to exist.
func TestAddGoogleOAuthToken(t *testing.T) {
	tok := GoogleOAuthToken{
		PK:           "test1234",
		AccessToken:  "test123",
		TokenType:    "Bearer",
		RefreshToken: "test123",
		Expiry:       "2023-09-24T11:31:54.2936004+08:00",
	}
	err := testTableBasics.AddGoogleOAuthToken(tok)
	// Fail on error rather than only logging it; otherwise the test passes
	// even when nothing was written.
	assert.NoError(t, err, "Expected no error adding token")
}
// TestTxUpdateGoogleOAuthToken updates the access and refresh tokens of the
// "test1234" item through a TransactWriteItems call.
func TestTxUpdateGoogleOAuthToken(t *testing.T) {
	tok := GoogleOAuthToken{
		PK:           "test1234",
		AccessToken:  "test123456",
		RefreshToken: "test123456",
	}
	output, err := testTableBasics.TxUpdateGoogleOAuthToken(tok)
	t.Log("output:", output)
	// Fail on error rather than only logging it.
	assert.NoError(t, err, "Expected no error updating token")
}
// TestGetGoogleOAuthToken reads back the token for LINE user "test1234".
// It expects TestAddGoogleOAuthToken to have stored that item first.
func TestGetGoogleOAuthToken(t *testing.T) {
	tok, err := testTableBasics.GetGoogleOAuthToken("test1234")
	t.Log("Get Token:", tok)
	// Fail on error rather than only logging it, and make sure an item was
	// actually found (a missing item comes back without an access token).
	assert.NoError(t, err, "Expected no error getting token")
	assert.NotEmpty(t, tok.AccessToken, "Expected a stored access token; run TestAddGoogleOAuthToken first")
}